import matplotlib.pyplot as mp
import pickle
from PIL import Image
import tensorflow.keras as keras
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.models import *
from tensorflow.python.keras.optimizers import *
from tensorflow.python.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
import glob, pickle
import random
import time
import tensorflow as tf
import tensorflow.keras.backend as K

# Digit characters available for CAPTCHA labels.
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
CAPTCHA_CHARSET = NUMBER  # Character set used for CAPTCHAs (digits only)
CAPTCHA_LEN = 4  # Number of characters in one CAPTCHA
CAPTCHA_HEIGHT = 60  # CAPTCHA image height (rows)
CAPTCHA_WIDTH = 160  # CAPTCHA image width (columns)
TRAIN_DATA_DIR = './images/'  # Path to the CAPTCHA training data set
import numpy as np
import tensorflow.gfile as gfile

# Directory that the evaluation script below scans for *.jpg test images.
TEST_DATA_DIR = './test/'


# 适配Keras图像数据格式通道
def fit_keras_channels(batch, rows=CAPTCHA_HEIGHT, cols=CAPTCHA_WIDTH):
    """Add a single-channel axis to ``batch`` in the backend's data format.

    Args:
        batch: Array exposing ``shape`` and ``reshape``; it is reshaped to
            ``batch.shape[0]`` samples of ``rows * cols`` values each.
        rows: Image height.
        cols: Image width.

    Returns:
        A ``(batch, input_shape)`` tuple. ``input_shape`` is
        ``(1, rows, cols)`` when the Keras backend reports
        ``'channels_first'`` and ``(rows, cols, 1)`` otherwise; ``batch`` is
        reshaped to ``(n_samples,) + input_shape``.
    """
    # Keras reports 'channels_first' / 'channels_last'. The previous
    # comparison against 'channel first' could never match, so the
    # channels-first layout was silently ignored.
    if K.image_data_format() == 'channels_first':
        input_shape = (1, rows, cols)
    else:
        input_shape = (rows, cols, 1)
    batch = batch.reshape((batch.shape[0],) + input_shape)
    return batch, input_shape


# 灰度化
def rgb2gray(image):
    """Collapse the colour channels of ``image`` into one grayscale value.

    The first three entries along the last axis (assumed to be R, G, B, as
    the function name suggests) are combined with the weights
    0.299, 0.587 and 0.114. Any further entries on that axis are ignored.

    Args:
        image: Array whose last axis holds at least three channel values.

    Returns:
        Array with the last axis removed, holding the weighted sums.
    """
    luma_weights = [0.299, 0.587, 0.114]
    colour_planes = image[..., :3]
    return np.dot(colour_planes, luma_weights)


# one-hot编码
def text2vec(text, length=CAPTCHA_LEN, charset=CAPTCHA_CHARSET):
    """One-hot encode ``text`` into a flat vector.

    Position ``i`` of the text owns the slice
    ``[i * len(charset), (i + 1) * len(charset))`` of the result; within
    that slice the entry at the character's index in ``charset`` is set to 1.

    Args:
        text: The label to encode; must contain exactly ``length`` items.
        length: Required number of characters.
        charset: Sequence of allowed characters (must support ``index``).

    Returns:
        A 1-D numpy array of ``length * len(charset)`` zeros and ones.

    Raises:
        ValueError: If ``len(text) != length``, or (raised by
            ``charset.index``) if a character is not in ``charset``.
    """
    actual_len = len(text)
    if actual_len != length:
        # Reject labels whose length does not match the expected CAPTCHA size.
        raise ValueError(
            "输入字符长度为{}，与所需验证码长度{}不相符".format(actual_len, length))

    n_classes = len(charset)
    encoded = np.zeros(length * n_classes)
    for position, char in enumerate(text):
        encoded[position * n_classes + charset.index(char)] = 1
    return encoded


# 向量转为字符
def vec2text(vector):
    """Decode a one-hot / score vector back into its CAPTCHA string.

    The input is reshaped to ``CAPTCHA_LEN`` rows; for each row the index of
    the largest value selects one character from ``CAPTCHA_CHARSET``.

    Args:
        vector: numpy array or array-like whose total size is divisible by
            ``CAPTCHA_LEN``.

    Returns:
        The decoded string, one character per row.
    """
    scores = vector if isinstance(vector, np.ndarray) else np.asarray(vector)
    per_char = np.reshape(scores, [CAPTCHA_LEN, -1])
    return ''.join(CAPTCHA_CHARSET[np.argmax(row)] for row in per_char)


class MCDropout(Dropout):
    """Dropout layer whose forward pass always requests training mode.

    ``call`` forwards to ``Dropout.call`` with ``training=True`` hard-coded.
    NOTE(review): presumably meant for Monte Carlo dropout, i.e. keeping
    dropout active at prediction time -- confirm with the training code.
    """

    def call(self, inputs):
        # Always pass training=True to the parent implementation.
        return super().call(inputs, training=True)


def _load_test_set(data_dir):
    """Read every ``*.jpg`` in ``data_dir`` and return ``(images, labels)``.

    ``images`` is a list of RGB pixel arrays and ``labels`` the matching
    file names without directory and extension. Files are visited in sorted
    order so that sample indices are reproducible between runs.
    """
    import os  # local import: the file's top-level imports do not include os

    images, labels = [], []
    for filename in sorted(glob.glob(os.path.join(data_dir, '*.jpg'))):
        # Close each file handle promptly; convert() yields a loaded copy
        # and guarantees three channels for rgb2gray.
        with Image.open(filename) as img:
            images.append(np.array(img.convert('RGB')))
        # str.lstrip/rstrip strip *character sets*, not prefixes/suffixes,
        # so derive the label from the path components instead.
        labels.append(os.path.splitext(os.path.basename(filename))[0])
    return images, labels


def main():
    """Evaluate the saved CAPTCHA model on the images in ``TEST_DATA_DIR``.

    Prints the true and predicted text of one sample, then the overall
    sample count, number correct, number wrong and accuracy.

    Raises:
        SystemExit: If no test images are found.
        ValueError: (from ``text2vec``) if a file name is not a valid label.
    """
    images, labels = _load_test_set(TEST_DATA_DIR)
    if not images:
        # Avoids a shape error in rgb2gray and a division by zero below.
        raise SystemExit(
            'No *.jpg test images found in {}'.format(TEST_DATA_DIR))

    # Images: grayscale, scale to [0, 1], add the channel axis.
    X_test = np.array(images, dtype=np.float32)
    X_test = rgb2gray(X_test) / 255
    X_test, _ = fit_keras_channels(X_test)
    # Labels: one-hot encode each file-name label.
    Y_test = np.asarray([text2vec(label) for label in labels])

    model = keras.models.load_model('./model/cnn_best_vgg_0.99.h5')

    # Show a single example; clamp the index so small test sets still work.
    sample_idx = min(22, len(Y_test) - 1)
    print(vec2text(Y_test[sample_idx]))
    # Slicing keeps the batch axis, so no hard-coded reshape is needed.
    print(vec2text(model.predict(X_test[sample_idx:sample_idx + 1])))

    # Score the whole test set with one batched predict call instead of
    # one call per image.
    predictions = model.predict(X_test)
    total = len(Y_test)
    count = sum(
        vec2text(pred) == vec2text(real)
        for pred, real in zip(predictions, Y_test)
    )
    print("样本数：{}，正确数:{}，错误数:{}，准确率:{}".format(total, count, total - count, count / total))


if __name__ == '__main__':
    main()

